{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST数据集分类任务：带Dropout的单隐藏层神经网络"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "需要的工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/shoremei/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-cf63bf3938a4>:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting data/train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting data/t10k-images-idx3-ubyte.gz\n",
      "Extracting data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "done\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# Load the MNIST archives kept under data/ (the helper fetches them first\n",
    "# when they are absent) with every label encoded as a one-hot vector.\n",
    "mnist = input_data.read_data_sets(train_dir='data/', one_hot=True)\n",
    "print(\"done\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "设置我们的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and training hyper-parameters.\n",
    "numClasses = 10            # width of the one-hot label vector\n",
    "inputSize = 784            # length of each flattened input vector\n",
    "hidden1_size = 300         # units in the hidden layer\n",
    "trainingIterations = 3000  # number of mini-batch updates\n",
    "batchSize = 100            # examples per mini-batch\n",
    "\n",
    "# Fix the random seeds before any graph op is created so that weight\n",
    "# initialisation, dropout masks and batch order repeat on a fresh\n",
    "# Restart & Run All.\n",
    "# NOTE(review): batch shuffling is assumed to draw from NumPy's global RNG\n",
    "# in this TensorFlow version - confirm.\n",
    "randomSeed = 42\n",
    "np.random.seed(randomSeed)\n",
    "tf.set_random_seed(randomSeed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "指定好x和y的大小"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph inputs. The leading None leaves the batch dimension open, so any\n",
    "# number of rows can be fed per run.\n",
    "X = tf.placeholder(dtype=tf.float32, shape=[None, inputSize])\n",
    "y = tf.placeholder(dtype=tf.float32, shape=[None, numClasses])\n",
    "# Dropout keep probability; supplied at run time through feed_dict.\n",
    "keep_prob = tf.placeholder(dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数初始化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n"
     ]
    }
   ],
   "source": [
    "# Hidden layer weights: small random values so the units start out different.\n",
    "W1 = tf.Variable(tf.truncated_normal([inputSize, hidden1_size], stddev=0.1))\n",
    "# One bias per hidden unit. The shape has to be given to tf.constant: the\n",
    "# second positional parameter of tf.Variable is `trainable`, so passing\n",
    "# [hidden1_size] there left B1 as a single scalar shared by every unit.\n",
    "B1 = tf.Variable(tf.constant(0.1, shape=[hidden1_size]))\n",
    "# Output layer weights and biases start at zero.\n",
    "W2 = tf.Variable(tf.zeros([hidden1_size, numClasses]))\n",
    "B2 = tf.Variable(tf.zeros([numClasses]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构造模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-5-636baefb9270>:2: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "WARNING:tensorflow:From /home/shoremei/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    }
   ],
   "source": [
    "# Hidden layer: affine transform followed by ReLU, then dropout whose keep\n",
    "# probability is fed through the keep_prob placeholder.\n",
    "hidden1 = tf.nn.relu(tf.matmul(X, W1) + B1)\n",
    "hidden1_drop = tf.nn.dropout(hidden1, keep_prob)\n",
    "\n",
    "# Keep the un-normalised class scores: the loss below is computed from them.\n",
    "logits = tf.matmul(hidden1_drop, W2) + B2\n",
    "y_pred = tf.nn.softmax(logits)\n",
    "\n",
    "# Loss function and optimizer.\n",
    "# The fused op applies softmax and log together, so a predicted probability\n",
    "# that underflows to 0 cannot turn the loss into NaN/inf the way\n",
    "# -sum(y * log(softmax(...))) can.\n",
    "cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))\n",
    "\n",
    "train_op = tf.train.AdagradOptimizer(0.3).minimize(cross_entropy)\n",
    "\n",
    "# Accuracy: fraction of rows whose highest-scoring class matches the label.\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_pred, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the op that assigns every variable its initial value, then open a\n",
    "# session and execute it so the variables are ready for training.\n",
    "init = tf.global_variables_initializer()\n",
    "sess = tf.Session()\n",
    "sess.run(init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "迭代计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0, training accuracy 1\n",
      "step 499, training accuracy 0.98\n",
      "step 998, training accuracy 1\n",
      "step 1497, training accuracy 1\n",
      "step 1996, training accuracy 1\n",
      "step 2495, training accuracy 0.99\n",
      "step 2994, training accuracy 0.99\n"
     ]
    }
   ],
   "source": [
    "for i in range(trainingIterations):\n",
    "    batchInput, batchLabels = mnist.train.next_batch(batchSize)\n",
    "\n",
    "    # One optimisation step with dropout active. train_op itself evaluates to\n",
    "    # None, so cross_entropy is fetched in the same run to obtain the batch loss.\n",
    "    trainingLoss = sess.run(\n",
    "        [train_op, cross_entropy],\n",
    "        feed_dict={X: batchInput, y: batchLabels, keep_prob: 0.75})[1]\n",
    "\n",
    "    if i % 499 == 0:\n",
    "        # Measure accuracy with dropout switched off (keep_prob=1.0); leaving\n",
    "        # it on would score a randomly thinned network.\n",
    "        train_accuracy = sess.run(\n",
    "            accuracy, feed_dict={X: batchInput, y: batchLabels, keep_prob: 1.0})\n",
    "        print (\"step %d, training accuracy %g\"%(i, train_accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test accuracy 0.97\n"
     ]
    }
   ],
   "source": [
    "# Score the whole test split rather than one random batch, and switch dropout\n",
    "# off (keep_prob=1.0) so the reported accuracy reflects the full network and\n",
    "# does not change from run to run.\n",
    "testAccuracy = sess.run(\n",
    "    accuracy,\n",
    "    feed_dict={X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})\n",
    "print (\"test accuracy %g\"%(testAccuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
